{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Klhdy8pnk5J8"
      },
      "source": [
        "**A tool to visualize the segmentation model inference output.**\\\n",
        "This tool is used verify that the exported tflite can produce expected segmentation results.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-vGHZSPWXbyu"
      },
      "outputs": [],
      "source": [
        "MODEL='gs://**/placeholder_for_edgetpu_models/autoseg/segmentation_search_edgetpu_s_not_fused.tflite'#@param\n",
        "IMAGE_HOME = 'gs://**/PS_Compare/20190711'#@param\n",
        "# Relative image file names separated by comas.\n",
        "TEST_IMAGES = 'ADE_val_00001626.jpg,ADE_val_00001471.jpg,ADE_val_00000557.jpg'#@param\n",
        "IMAGE_WIDTH = 512 #@param\n",
        "IMAGE_HEIGHT = 512 #@param"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zzhF1ASDkxTU"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from PIL import Image as PILImage\n",
        "import matplotlib.pyplot as plt\n",
        "from scipy import ndimage"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AXaJgLg1ml16"
      },
      "outputs": [],
      "source": [
        "# This block creates local copies of /cns and /x20 files.\n",
        "TEST_IMAGES=','.join([IMAGE_HOME+'/'+image for image in TEST_IMAGES.split(',')])\n",
        "\n",
        "# The tflite interpreter only accepts model in local path.\n",
        "def local_copy(awaypath):\n",
        "  localpath = '/tmp/' + awaypath.split('/')[-1]\n",
        "  !rm -f {localpath}\n",
        "  !fileutil cp -f {awaypath} {localpath}\n",
        "  !ls -lht {localpath}\n",
        "  %download_file {localpath}\n",
        "  return localpath\n",
        "\n",
        "IMAGES = [local_copy(image) for image in TEST_IMAGES.split(',')]\n",
        "MODEL_COPY=local_copy(MODEL)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KhS1lOrxHp5C"
      },
      "outputs": [],
      "source": [
        "# Creates a 6px wide boolean edge mask to highlight the segmentation.\n",
        "def edge(mydata):\n",
        "  mydata = mydata.reshape(512, 512)\n",
        "  mydatat = mydata.transpose([1, 0])\n",
        "  mydata = np.convolve(mydata.reshape(-1), [-1, 0, 1], mode='same').reshape(512, 512)\n",
        "  mydatat = np.convolve(mydatat.reshape(-1), [-1, 0, 1], mode='same').reshape(512, 512).transpose([1, 0])\n",
        "  mydata = np.maximum((mydata != 0).astype(np.int8), (mydatat != 0).astype(np.int8))\n",
        "  mydata = ndimage.binary_dilation(mydata).astype(np.int8)\n",
        "  mydata = ndimage.binary_dilation(mydata).astype(np.int8)\n",
        "  mydata = ndimage.binary_dilation(mydata).astype(np.int8)\n",
        "  return mydata"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GdlsbiVqL5JZ"
      },
      "outputs": [],
      "source": [
        "def run_model(input_data):\n",
        "  _input_data = input_data\n",
        "  _input_data = (_input_data-128).astype(np.int8)\n",
        "  # Load the tflite model and allocate tensors.\n",
        "  interpreter_x = tf.lite.Interpreter(model_path=MODEL_COPY)\n",
        "  interpreter_x.allocate_tensors()\n",
        "  # Get input and output tensors.\n",
        "  input_details = interpreter_x.get_input_details()\n",
        "  output_details = interpreter_x.get_output_details()\n",
        "  interpreter_x.set_tensor(input_details[0]['index'], _input_data)\n",
        "  interpreter_x.invoke()\n",
        "  output_data = interpreter_x.get_tensor(output_details[0]['index'])\n",
        "  return output_data.reshape((512, 512, 1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1mot5M_nl5P7"
      },
      "outputs": [],
      "source": [
        "# Set visualization wind sizes.\n",
        "fig, ax = plt.subplots(max(len(IMAGES),2), 3)\n",
        "fig.set_figwidth(30)\n",
        "fig.set_figheight(10*max(len(IMAGES),2))\n",
        "\n",
        "# Read and test image.\n",
        "for r, image in enumerate(IMAGES):\n",
        "  im = PILImage.open(image).convert('RGB')\n",
        "  min_dim=min(im.size[0], im.size[1])\n",
        "  im = im.resize((IMAGE_WIDTH*im.size[0] // min_dim, IMAGE_HEIGHT*im.size[1] // min_dim))\n",
        "  input_data = np.expand_dims(im, axis=0)\n",
        "  input_data = input_data[:, :IMAGE_WIDTH,:IMAGE_HEIGHT]\n",
        "  ax[r, 0].imshow(input_data.reshape([512, 512, 3]).astype(np.uint8))\n",
        "  ax[r, 0].set_title('Original')\n",
        "  ax[r, 0].grid(False)\n",
        "\n",
        "  # Test the model on random input data.\n",
        "  output_data = run_model(input_data)\n",
        "  ax[r, 1].imshow(output_data, vmin = 0, vmax = 32)\n",
        "  ax[r, 1].set_title('Segmentation')\n",
        "  ax[r, 1].grid(False)\n",
        "\n",
        "  output_data = np.reshape(np.minimum(output_data, 32), [512,512])\n",
        "  output_edge = edge(output_data).reshape(512,512, 1)\n",
        "  output_data = np.stack([output_data%3, (output_data//3)%3, (output_data//9)%3], axis = -1)\n",
        "  \n",
        "  output_data = input_data.reshape([512, 512, 3]).astype(np.float32) * (1-output_edge) + output_data * output_edge * 255\n",
        "  ax[r, 2].imshow(output_data.astype(np.uint8), vmin = 0, vmax = 256)\n",
        "  ax[r, 2].set_title('Segmentation \u0026 original')\n",
        "  ax[r, 2].grid(False)\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//quality/ranklab/experimental/notebook:rl_colab",
        "kind": "private"
      },
      "name": "Inference_visualization_tool.ipynb",
      "private_outputs": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
